from copy import deepcopy
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from ml_collections import ConfigDict

from algos.model import Scalar, update_target_network
from utilities.jax_utils import mse_loss, next_rng, value_and_multi_grad
from utilities.utils import prefix_metrics


class ConservativeSAC(object):

  @staticmethod
  def get_default_config(updates=None):
    """Return the default hyperparameter ConfigDict, optionally overridden.

    Args:
      updates: optional value accepted by `ConfigDict(...)`. When given, its
        references are resolved and the result is merged over the defaults
        via `ConfigDict.update`.

    Returns:
      A `ConfigDict` holding the defaults listed below, with any `updates`
      applied on top.
    """
    # (name, default) pairs, assigned in this order.
    defaults = (
      ('nstep', 1),
      ('discount', 0.99),
      ('alpha_multiplier', 1.0),
      ('use_automatic_entropy_tuning', True),
      ('backup_entropy', False),
      ('target_entropy', 0.0),
      ('encoder_lr', 3e-4),
      ('policy_lr', 3e-4),
      ('qf_lr', 3e-4),
      ('optimizer_type', 'adam'),
      ('soft_target_update_rate', 5e-3),
      ('use_cql', True),
      ('cql_n_actions', 10),
      ('cql_importance_sample', True),
      ('cql_lagrange', False),
      ('cql_target_action_gap', 1.0),
      ('cql_temp', 1.0),
      ('cql_min_q_weight', 5.0),
      ('cql_max_target_backup', False),
      ('cql_clip_diff_min', -np.inf),
      ('cql_clip_diff_max', np.inf),
      ('bc_mode', 'mse'),  # the other mode noted for this field is 'mle'
      ('bc_weight', 0.),
      ('res_hidden_size', 1024),
      ('encoder_blocks', 1),
      ('head_blocks', 1),
    )

    config = ConfigDict()
    for field_name, field_value in defaults:
      setattr(config, field_name, field_value)

    if updates is None:
      return config

    overrides = ConfigDict(updates).copy_and_resolve_references()
    config.update(overrides)
    return config

  def __init__(self, config, encoder, policy, qf, decoupled_q=False):
    """Initialise parameters and optimizer states for every trainable module.

    Args:
      config: overrides forwarded to `get_default_config`.
      encoder: module whose `init` takes an RNG key and an observation batch.
      policy: module providing `input_size`, `observation_dim`,
        `embedding_dim`, `action_dim` and an `init` taking two RNG keys and
        an embedding batch.
      qf: module whose `init` takes an RNG key, an embedding batch and an
        action batch; it is initialised twice to obtain `qf1` and `qf2`.
      decoupled_q: flag stored on the instance; this constructor does not
        otherwise read it.
    """
    self.config = self.get_default_config(config)
    self.decoupled_q = decoupled_q
    self.encoder = encoder
    self.policy = policy
    self.qf = qf
    # NOTE(review): observation_dim is read from `policy.input_size` here,
    # while the encoder dummy input below uses `policy.observation_dim` --
    # confirm both attributes exist and agree.
    self.observation_dim = policy.input_size
    self.embedding_dim = policy.embedding_dim
    self.action_dim = policy.action_dim

    self._train_states = {}

    # Raises KeyError for an optimizer_type other than 'adam' / 'sgd'.
    make_optimizer = {
      'adam': optax.adam,
      'sgd': optax.sgd,
    }[self.config.optimizer_type]

    def dummy_batch(width):
      # Zero-filled placeholder input with 10 rows, used only for `init`.
      return jnp.zeros((10, width))

    def register(name, params, learning_rate):
      # Store a TrainState for `params` under `name` and hand `params` back.
      self._train_states[name] = TrainState.create(
        params=params,
        tx=make_optimizer(learning_rate),
        apply_fn=None,
      )
      return params

    register(
      'encoder',
      self.encoder.init(next_rng(), dummy_batch(self.policy.observation_dim)),
      self.config.encoder_lr,
    )

    register(
      'policy',
      self.policy.init(
        next_rng(), next_rng(), dummy_batch(self.embedding_dim)
      ),
      self.config.policy_lr,
    )

    # Two independently initialised Q-functions sharing one module definition.
    q_params = {}
    for q_name in ('qf1', 'qf2'):
      q_params[q_name] = register(
        q_name,
        self.qf.init(
          next_rng(),
          dummy_batch(self.embedding_dim),
          dummy_batch(self.action_dim),
        ),
        self.config.qf_lr,
      )
    # Target parameters start as a deep copy of the freshly initialised ones.
    self._target_qf_params = deepcopy(q_params)

    trainable_keys = ['policy', 'qf1', 'qf2', 'encoder']

    if self.config.use_automatic_entropy_tuning:
      self.log_alpha = Scalar(0.0)
      # log_alpha is optimised with the policy learning rate.
      register(
        'log_alpha', self.log_alpha.init(next_rng()), self.config.policy_lr
      )
      trainable_keys.append('log_alpha')

    if self.config.cql_lagrange:
      self.log_alpha_prime = Scalar(1.0)
      # log_alpha_prime is optimised with the Q-function learning rate.
      register(
        'log_alpha_prime',
        self.log_alpha_prime.init(next_rng()),
        self.config.qf_lr,
      )
      trainable_keys.append('log_alpha_prime')

    self._model_keys = tuple(trainable_keys)
    self._total_steps = 0

  def train(self, batch, weight_eval, weight_improve, weight_constraint, bc=False):
    """Run one update via `_train_step` and store the resulting state.

    Increments the step counter, calls `self._train_step` with the current
    train states, the target Q parameters, a fresh key from `next_rng()` and
    the arguments below, then keeps the returned states and target parameters.

    Args:
      batch: training batch, forwarded unchanged.
      weight_eval: forwarded unchanged to `_train_step`.
      weight_improve: forwarded unchanged to `_train_step`.
      weight_constraint: forwarded unchanged to `_train_step`.
      bc: forwarded unchanged to `_train_step`; presumably toggles a
        behaviour-cloning mode -- confirm against `_train_step`.

    Returns:
      The metrics object returned by `_train_step`.
    """
    self._total_steps += 1
    new_states, new_target_params, metrics = self._train_step(
      self._train_states,
      self._target_qf_params,
      next_rng(),
      batch,
      weight_eval,
      weight_improve,
      weight_constraint,
      bc,
    )
    self._train_states = new_states
    self._target_qf_params = new_target_params
    return metrics